import os
import sys
import multiprocessing
import psutil
import logging
from concurrent.futures import ThreadPoolExecutor, as_completed, TimeoutError

from base.context import Context
from base.build_graph import BuildGraph
from base.defines import CompileType
from utils.common_utils import run_notty_command

class BuildPool(object):
    """Thread pool that replays captured compile/link commands to produce maple IR.

    GCC/CLANG work nodes are compiled source -> AST -> maple IR, trying up to
    three AST command variants (plain, with system headers, retry command).
    AR/LD work nodes are merged with the maple link command. Each failure is
    written to a per-node log under ERROR_DIR and counted in build_statistics.
    """

    def __init__(self, finished_queue=None, build_statistics=None):
        # Shared project context (directories, pool-size override, debug flag).
        self.context = Context()
        self.build_graph = BuildGraph()
        self.build_statistics = build_statistics
        self.build_pool = None
        # -1 means "not resolved yet"; computed in init_build_pool().
        self.build_pool_size = -1
        # Fraction of available CPUs the pool may use by default.
        self.thread_num_limit = 0.80
        self.finished_queue = finished_queue
        # True while the executor is alive (between init and shutdown).
        self.work_status = False
        self.thread_futures = []

    def init_build_pool(self):
        """Create the thread pool once; further calls are no-ops until shutdown()."""
        if not self.work_status:
            # max_workers must be a positive int: 0.80 * cpu_count is a float,
            # so truncate and clamp to at least one worker.
            self.build_pool_size = max(1, int(self.thread_num_limit * self.max_cpu_count()))
            if self.context.BUILD_POOL_SIZE >= 0:
                self.build_pool_size = self.context.BUILD_POOL_SIZE
            self.build_pool = ThreadPoolExecutor(max_workers=self.build_pool_size)
            self.work_status = True

    def max_cpu_count(self):
        """Return the usable CPU count, trying several probes; always >= 1."""
        _max_cpu_count = psutil.cpu_count()
        if not _max_cpu_count:
            _max_cpu_count = multiprocessing.cpu_count()
        if not _max_cpu_count:
            _max_cpu_count = psutil.cpu_count(logical=False)
        # Every probe may report None on exotic platforms; guarantee a positive
        # value so init_build_pool() never multiplies by None.
        return _max_cpu_count or 1

    def submit(self, job=None):
        """Queue one work node for asynchronous building.

        job: tuple (work_id, cwd, compile_type, parser).
        Raises ValueError for a missing job (assert would vanish under -O).
        """
        if not job:
            raise ValueError("submit() requires a work node (work_id, cwd, compile_type, parser)")
        # Fail fast on malformed nodes before handing them to a worker thread.
        (work_id, _, _, _) = job
        self.thread_futures.append(self.build_pool.submit(self.ir_build, job))

    def wait_for_all_done(self):
        """Block until every submitted job finishes, reporting each work id."""
        # Detach the future list up front instead of remove()-ing inside the
        # loop: remove() is O(n) per future and mutated the very list the
        # futures were taken from.
        pending, self.thread_futures = self.thread_futures, []
        for future in as_completed(pending):
            self.finished_queue.put(future.result())

    def ir_build(self, work_node=()):
        """Replay the build commands for one work node and return its work_id.

        work_node: tuple (work_id, cwd, compile_type, parser).
        Always returns work_id; success/failure is recorded through
        build_statistics and the per-node error log.
        """
        (work_id, cwd, compiler_type, parser) = work_node
        if compiler_type == CompileType.GCC or compiler_type == CompileType.CLANG:
            # Progressively more permissive AST command variants; stop at the
            # first one that produces the .ast file.
            attempts = (
                (parser.get_maple_compile_command, ""),
                (parser.get_maple_compile_command_with_sys_headers, ""),
                (parser.get_retry_maple_compile_command, "Retry "),
            )
            maple_command = None
            for command_source, log_tag in attempts:
                [ast_command, maple_command] = command_source()
                if self._replay_ast_command(work_id, cwd, parser, ast_command, log_tag):
                    break
            else:
                # All AST variants failed; count the node once and stop here.
                self.build_statistics.inc_ast_failed_targets()
                return work_id
            if not self._replay_maple_command(work_id, cwd, maple_command):
                self.build_statistics.inc_mpl_failed_targets()
                return work_id
        elif compiler_type == CompileType.AR or compiler_type == CompileType.LD:
            command = parser.get_maple_link_command()
            ret, std_out, std_err = run_notty_command(command=command, cwd=cwd)
            logging.debug("Replay Command: %s", " ".join(command))
            if self.context.ENABLE_DEBUG_MODE:
                print("[INFO] Transform Link command:\n\t%s\nto Merge Command:\n\t%s" % (' '.join(parser.get_original_command()), ' '.join(command)))
            output = os.path.join(self.context.MAPLE_DIR, work_id + self.context.MAPLE_EXT)
            if not os.path.exists(output):
                self.build_statistics.inc_mpl_failed_targets()
                self.error_logger(node=work_id, std_out=std_out, std_err=std_err, command=command, cwd=cwd)
                return work_id
        self.build_statistics.inc_succ_targets()
        return work_id

    def _replay_ast_command(self, work_id, cwd, parser, ast_command, log_tag=""):
        """Run one AST command variant; log failures; True iff the .ast file exists."""
        ret, std_out, std_err = run_notty_command(command=ast_command, cwd=cwd)
        logging.debug("Replay %sAST Command: %s", log_tag, " ".join(ast_command))
        if self.context.ENABLE_DEBUG_MODE:
            print("[INFO] Transform Compile command:\n\t%s\nto AST Command:\n\t%s" % (' '.join(parser.get_original_command()), ' '.join(ast_command)))
        output = os.path.join(self.context.AST_DIR, work_id + self.context.AST_EXT)
        if os.path.exists(output):
            return True
        # The output file is the only success signal: a missing .ast means the
        # command failed, whatever its exit status said.
        self.error_logger(node=work_id, std_out=std_out, std_err=std_err, command=ast_command, cwd=cwd)
        return False

    def _replay_maple_command(self, work_id, cwd, maple_command):
        """Run the AST->maple command; log failures; True iff the .mpl file exists."""
        ret, std_out, std_err = run_notty_command(command=maple_command, cwd=cwd)
        logging.debug("Replay Maple Command: %s", " ".join(maple_command))
        output = os.path.join(self.context.MAPLE_DIR, work_id + self.context.MAPLE_EXT)
        if os.path.exists(output):
            return True
        self.error_logger(node=work_id, std_out=std_out, std_err=std_err, command=maple_command, cwd=cwd)
        return False

    def shutdown(self):
        """Tear down the executor (waits for running jobs) and mark the pool idle."""
        if self.work_status:
            self.build_pool.shutdown()
            self.work_status = False

    def error_logger(self, node="", std_out="", std_err="", command=None, cwd=""):
        """Append a failure report for `node` to <ERROR_DIR>/<node>.log."""
        # A mutable default argument ([]) would be shared across every call;
        # normalise None to a fresh empty list instead.
        command = command if command is not None else []
        err_log_path = os.path.join(self.context.ERROR_DIR, node + ".log")
        content = "Replay failure occurred when build node: %s\n" % node
        content += "Under CWD: %s\n" % cwd
        content += "Command: %s\n" % " ".join(command)
        if std_out:
            content += "STD_OUT:\n\t %s\n" % std_out
        if std_err:
            content += "STD_ERR:\n\t %s\n" % std_err
        with open(err_log_path, 'a') as f:
            f.write(content)